{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 环境配置"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 安装conda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "deletable": false,
    "editable": false,
    "run_control": {
     "frozen": true
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-08-18 20:22:58--  https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh\n",
      "正在解析主机 repo.continuum.io (repo.continuum.io)... 104.18.201.79, 104.18.200.79, 2606:4700::6812:c84f, ...\n",
      "正在连接 repo.continuum.io (repo.continuum.io)|104.18.201.79|:443... 已连接。\n",
      "已发出 HTTP 请求，正在等待回应... 301 Moved Permanently\n",
      "位置：https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh [跟随至新的 URL]\n",
      "--2020-08-18 20:23:02--  https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n",
      "正在解析主机 repo.anaconda.com (repo.anaconda.com)... 104.16.130.3, 104.16.131.3, 2606:4700::6810:8303, ...\n",
      "正在连接 repo.anaconda.com (repo.anaconda.com)|104.16.130.3|:443... 已连接。\n",
      "已发出 HTTP 请求，正在等待回应... 200 OK\n",
      "长度：93052469 (89M) [application/x-sh]\n",
      "正在保存至: “Miniconda3-latest-Linux-x86_64.sh”\n",
      "\n",
      "16% [=====>                                 ] 15,456,713   246KB/s 剩余 3m 48s   ^C\n"
     ]
    }
   ],
   "source": [
    "# https://blog.csdn.net/weixin_43840215/article/details/89599559\n",
    "!wget -c https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "deletable": false,
    "editable": false,
    "run_control": {
     "frozen": true
    }
   },
   "outputs": [],
   "source": [
    "!chmod 777 Miniconda3-latest-Linux-x86_64.sh\n",
    "!sh Miniconda3-latest-Linux-x86_64.sh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "deletable": false,
    "editable": false,
    "run_control": {
     "frozen": true
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/bin/bash: conda: 未找到命令\n",
      "/bin/bash: conda: 未找到命令\n",
      "/bin/bash: conda: 未找到命令\n",
      "/bin/bash: conda: 未找到命令\n",
      "/bin/bash: conda: 未找到命令\n",
      "/bin/bash: conda: 未找到命令\n"
     ]
    }
   ],
   "source": [
    "# 上面链接都是https的。然而我在下载的时候报错了。应该把s删掉\n",
    "!conda config --add channels http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/\n",
    "!conda config --add channels http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/\n",
    "!conda config --add channels http://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/\n",
    "!conda config --add channels http://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/bioconda/\n",
    "!conda config --set show_channel_urls yes \n",
    "!conda config --get channels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "deletable": false,
    "editable": false,
    "run_control": {
     "frozen": true
    }
   },
   "outputs": [],
   "source": [
    "# https://blog.csdn.net/zzq060143/article/details/88042075\n",
    "# 请仔细参考这个，添加镜像\n",
    "!conda install pytorch torchvision cudatoolkit=10.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/root/miniconda3/bin/python'"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import sys\n",
    "\n",
    "sys.executable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "deletable": false,
    "editable": false,
    "run_control": {
     "frozen": true
    }
   },
   "outputs": [],
   "source": [
    "# 由于我是使用conda安装的，默认使用的是conda的base环境。\n",
    "# 所以其他环境是使用不了torch的\n",
    "# 先要激活conda环境，并且退出当前的virtualenv\n",
    "\n",
    "!deactivate # 在当前环境直接输入这个指令即可\n",
    "!source miniconda3/bin/activate\n",
    "!conda jupyter notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.DataFrame(np.arange(8).reshape(4,2),columns=['a','b'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = pd.DataFrame(np.arange(4,12).reshape(4,2),columns=['c','d'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 可以使用这种方法同时修改训练集和测试集\n",
    "all_data = [data, test]\n",
    "for i in all_data:\n",
    "    i.iloc[:,1] = np.arange(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 查看电脑的显卡能不能使用\n",
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 两个重要函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dir(torch) dir可以查看有哪些可用的函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# help(torch.cuda) help告诉我们如何使用这个函数"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# torch的学习"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0, 1, 2, 3],\n",
       "       [4, 5, 6, 7]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# torch类型数据\n",
    "# numpy类型数据\n",
    "\n",
    "np_data = np.arange(8).reshape(2,4)\n",
    "np_data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据类型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在torch中，如果是一个数字，就是0维，**在torch中没有string类型的数据**。如果要表示字符串类型，可以使用one-hot方法\n",
    "\n",
    "然而在实际的生活中，英语单词有上万多个，我们单独靠one-hot，会使得矩阵非常稀疏。而且不能表达出单词的近似程度。\n",
    "\n",
    "所以one-hot编码是有缺陷的，所以人们开发出来其他的方式来编码，编码叫Embedding， 里面有w2v， glove方式"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "floattensor  32\n",
    "\n",
    "doubletensor 64\n",
    "\n",
    "longtensor 64\n",
    "\n",
    "inttensor 32\n",
    "\n",
    "bytetensor 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'torch.FloatTensor'"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.randn(2,3)\n",
    "a.type() # 返回的是具体的数据类型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Tensor"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 返回的是基本的数据类型\n",
    "type(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 判断诗句类型是不是一致\n",
    "isinstance(a, torch.FloatTensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "deletable": false,
    "editable": false,
    "run_control": {
     "frozen": true
    }
   },
   "outputs": [],
   "source": [
    "# 将数据放到GPU的时候数据类型但是又会发生变化\n",
    "\n",
    "isinstance(data, torch.cuda.FloatTensor)\n",
    "\n",
    "data.cuda() # 将数据放到GPU上\n",
    "\n",
    "isinstance(data, torch.cuda.FloatTensor)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "0维的有什么用呢？通常loss就是0维的，也就是标量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.)"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 纬度为0的标量，维度为0就是标量\n",
    "torch.tensor(1.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.tensor(1.).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([])"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.tensor(1.).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.tensor(1.).dim()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "维度为1的张量有什么用呢？\n",
    "\n",
    "在神经元中，我们通常要设置偏置。ax+b中b就是。我们用一维张量表示"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1])"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 1 dim []\n",
    "torch.tensor([1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1])"
      ]
     },
     "execution_count": 94,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.tensor([1]).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "only one element tensors can be converted to Python scalars",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-95-6daa886b8cad>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m: only one element tensors can be converted to Python scalars"
     ]
    }
   ],
   "source": [
    "torch.tensor([1,1]).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.0798e-24,  4.5891e-41, -1.5148e+00],\n",
       "        [ 3.0624e-41,         nan,  0.0000e+00]])"
      ]
     },
     "execution_count": 97,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 2 dim [[],[]],多个列表嵌套就是2维\n",
    "\n",
    "a = torch.FloatTensor(2,3)\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.0798e-24,  4.5891e-41, -1.3108e+00]])"
      ]
     },
     "execution_count": 103,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b = torch.FloatTensor(1,3) # 注意这是随机生成的\n",
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-1.0798e-24,  4.5891e-41, -1.5225e+00],\n",
       "         [ 3.0624e-41,  4.4842e-44,  0.0000e+00]]])"
      ]
     },
     "execution_count": 106,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 3 dim [[[]],[[]]]\n",
    "# 3 dim有什么用呢？他们的使用场景非常广泛\n",
    "# 比如在nlp中，一句话有10个单词，每个单词用向量表示这句话就可以用2维tensor表示[10, 100]/假如每个单词用100个表示\n",
    "# 假如现在有20句话，用3 dim表示[20,10,100]\n",
    "c = torch.FloatTensor(1,2,3)\n",
    "c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.0798e-24,  4.5891e-41, -1.5225e+00],\n",
       "        [ 3.0624e-41,  4.4842e-44,  0.0000e+00]])"
      ]
     },
     "execution_count": 107,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-1.0798e-24,  4.5891e-41, -1.5225e+00])"
      ]
     },
     "execution_count": 109,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 更高维 4dim，5dim\n",
    "# [2,3,29,29]\n",
    "# 上面2表示照片数量，3表示每个照片颜色通道RGB，29*29表示长宽\n",
    "# 注意一个维度，只能用来描述一种特征"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 创建tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 未初始化的tensor，不要直接使用未初始化的数据，输入到其他地方\n",
    "# 未初始化只是为了申请内存，然后赋值给他。\n",
    "# 如果直接将未初始化的数据喂到其他地方。因为值是随机的。有些值是无穷大。容易出现问题\n",
    "\n",
    "torch.FloatTensor(2,2,2) # 未初始化的数据，其内部的值是随机的"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.8526, 0.3177, 0.6964],\n",
       "        [0.0450, 0.3220, 0.9346],\n",
       "        [0.5983, 0.3297, 0.9744]])"
      ]
     },
     "execution_count": 111,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 随机初始化\n",
    "torch.rand(3,3) # 初始化的值是在0-1之间"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.4837, 0.7445, 0.2592],\n",
       "         [0.3884, 0.1026, 0.6049],\n",
       "         [0.4994, 0.7637, 0.4178]],\n",
       "\n",
       "        [[0.3952, 0.8937, 0.4947],\n",
       "         [0.4957, 0.4705, 0.9129],\n",
       "         [0.7484, 0.5479, 0.8227]],\n",
       "\n",
       "        [[0.6375, 0.6115, 0.6347],\n",
       "         [0.4693, 0.2609, 0.5884],\n",
       "         [0.2203, 0.5070, 0.4755]]])"
      ]
     },
     "execution_count": 114,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.rand(3,3,3)\n",
    "c = torch.rand_like(a) # 读取一个tensor的shape，然后传给rand like。like相似\n",
    "c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[2, 9],\n",
       "        [3, 9]])"
      ]
     },
     "execution_count": 116,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.randint(1,10,[2,2]) # 在最大和最小值区间生成随机数，第三个参数是指定shape，以列表的形式传入，不包括最大值10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.8897,  1.5054,  0.1210],\n",
       "        [-0.1361,  0.3104, -0.7207],\n",
       "        [ 1.0253, -1.2426,  0.5322]])"
      ]
     },
     "execution_count": 118,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.randn(3,3) # 生成一个正太分布的数据。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0.],\n",
       "        [0., 0., 0.]])"
      ]
     },
     "execution_count": 126,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.full([2,3],0.0,) # 好像只能生成float类型的数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(3.)"
      ]
     },
     "execution_count": 129,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.full([],3.) # 创建一个标量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
      ]
     },
     "execution_count": 134,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.arange(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])"
      ]
     },
     "execution_count": 135,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.linspace(0,10,11)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0.0000,  1.1111,  2.2222,  3.3333,  4.4444,  5.5556,  6.6667,  7.7778,\n",
       "         8.8889, 10.0000])"
      ]
     },
     "execution_count": 137,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.linspace(0,10,10) # 等份切"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 1., 1.],\n",
       "        [1., 1., 1.],\n",
       "        [1., 1., 1.]])"
      ]
     },
     "execution_count": 138,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.ones(3,3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 0., 0.],\n",
       "        [0., 1., 0.],\n",
       "        [0., 0., 1.]])"
      ]
     },
     "execution_count": 139,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.eye(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0.],\n",
       "        [0., 0., 0.],\n",
       "        [0., 0., 0.]])"
      ]
     },
     "execution_count": 140,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.zeros(3,3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2, 4, 3, 1, 8, 0, 7, 6, 9, 5])"
      ]
     },
     "execution_count": 141,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.randperm(10) # 生成一个随机的序列"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 索引和切片"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.9975, 0.8395],\n",
       "         [0.4927, 0.9900],\n",
       "         [0.3013, 0.9192]],\n",
       "\n",
       "        [[0.7212, 0.0430],\n",
       "         [0.5839, 0.6931],\n",
       "         [0.6613, 0.2377]]])"
      ]
     },
     "execution_count": 145,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = torch.rand(2,3,2)\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.9975, 0.8395],\n",
       "        [0.4927, 0.9900],\n",
       "        [0.3013, 0.9192]])"
      ]
     },
     "execution_count": 146,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[0] # 索引的是第0个3，2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.9975, 0.8395])"
      ]
     },
     "execution_count": 149,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[0,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.8395)"
      ]
     },
     "execution_count": 148,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[0,0,1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.9975, 0.8395],\n",
      "        [0.4927, 0.9900],\n",
      "        [0.3013, 0.9192]])\n",
      "tensor([[0.7212, 0.0430],\n",
      "        [0.5839, 0.6931],\n",
      "        [0.6613, 0.2377]])\n"
     ]
    }
   ],
   "source": [
    "for each in data: # 居然可以遍历\n",
    "    print(each)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.9975, 0.8395],\n",
       "         [0.4927, 0.9900],\n",
       "         [0.3013, 0.9192]]])"
      ]
     },
     "execution_count": 153,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 切片\n",
    "data[:1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 154,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[:1].dim() # 依旧是三维的 # 这里是对第一个通道索引"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.9975, 0.8395]]])"
      ]
     },
     "execution_count": 157,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[:1,:1] # 还可以对第二个通道索引"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.3013, 0.9192]]])"
      ]
     },
     "execution_count": 158,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[:1,-1:] # -1开始表示只选取末尾"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 160,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.9975, 0.8395],\n",
       "         [0.4927, 0.9900],\n",
       "         [0.3013, 0.9192]]])"
      ]
     },
     "execution_count": 160,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[:1:2,:,:] # 甚至可以隔行采样"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 162,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.9250, 0.9249, 0.6874],\n",
       "         [0.3065, 0.2696, 0.8503],\n",
       "         [0.0577, 0.7079, 0.6500]],\n",
       "\n",
       "        [[0.1635, 0.2214, 0.2830],\n",
       "         [0.1326, 0.0176, 0.5144],\n",
       "         [0.7082, 0.1147, 0.5520]],\n",
       "\n",
       "        [[0.0726, 0.4157, 0.1398],\n",
       "         [0.6775, 0.5979, 0.3828],\n",
       "         [0.5380, 0.3950, 0.5863]]])"
      ]
     },
     "execution_count": 162,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = torch.rand(3,3,3)\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# index_select takes the indices as an integer tensor; a plain Python list raises TypeError.\n",
    "# Pick entries 0 and 2 along dim 0.\n",
    "data.index_select(0, torch.tensor([0,2]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 172,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.,  2.,  3.,  4.],\n",
       "        [ 5.,  6.,  7.,  8.],\n",
       "        [ 9., 10., 11., 12.]])"
      ]
     },
     "execution_count": 172,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# torch.select_index用法\n",
    "test = torch.linspace(1,12, 12).view(3,4)\n",
    "test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.,  2.],\n",
       "        [ 5.,  6.],\n",
       "        [ 9., 10.]])"
      ]
     },
     "execution_count": 175,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.index_select(1, torch.tensor([0,1])) # https://blog.csdn.net/g_blink/article/details/102854188\n",
    "\n",
    "# 参数说明：第一个参数0表示按行索引，1表示按列索引。第二个参数传入一个tensor，表示传入的index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.,  2.,  3.,  4.],\n",
       "        [ 5.,  6.,  7.,  8.],\n",
       "        [ 9., 10., 11., 12.]])"
      ]
     },
     "execution_count": 178,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test[...] # 三个句号表示所有维度都取。:::替代这个"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([5., 6., 7., 8.])"
      ]
     },
     "execution_count": 177,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test[1,...] # 可以先给定一个维度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.,  3.],\n",
       "        [ 5.,  7.],\n",
       "        [ 9., 11.]])"
      ]
     },
     "execution_count": 179,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test[...,::2] # 也可以这样混合着使用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 181,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.3531,  0.2694, -0.6561,  0.3088],\n",
       "        [ 0.9426, -0.3599, -0.1596, -1.0377],\n",
       "        [-0.4034,  1.9707,  0.0576,  0.8735]])"
      ]
     },
     "execution_count": 181,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 设置阈值，将大于阈值的位置设置为1\n",
    "\n",
    "x = torch.randn(3,4)\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 182,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ True, False, False, False],\n",
       "        [ True, False, False, False],\n",
       "        [False,  True, False,  True]])"
      ]
     },
     "execution_count": 182,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.ge(0.5) # 设置阈值，大于阈值的返回true"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## tensor维度增加"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 192,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# view，reshape\n",
    "# 这里举例为minist手写数据集。4表示4张图片，3表示颜色只有RGB。28*28表示长宽\n",
    "test = torch.rand(4,3,28,28) # rand生成的数是在0-1之间"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 193,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 28, 28])"
      ]
     },
     "execution_count": 193,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 198,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 第一个view操作，合并维度\n",
    "\n",
    "res = test.view(4,3*28*28) # 把后面几个维度合并，特别适合神经网络中的全连接层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 196,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 2352])"
      ]
     },
     "execution_count": 196,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 200,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.1116, 0.5150, 0.2852,  ..., 0.7692, 0.9233, 0.0275],\n",
       "        [0.7971, 0.8134, 0.7538,  ..., 0.2127, 0.4565, 0.0840],\n",
       "        [0.6096, 0.0647, 0.9465,  ..., 0.7477, 0.9180, 0.1116],\n",
       "        [0.2819, 0.0272, 0.8308,  ..., 0.9183, 0.9931, 0.2143]])"
      ]
     },
     "execution_count": 200,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 201,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = res.view(4,3,28,28) # 数据的恢复。可以将数据再展回去"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 202,
   "metadata": {},
   "outputs": [],
   "source": [
    "# squeeze,unsqueeze\n",
    "\n",
    "test = torch.rand(4,1,28,28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 203,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 4, 1, 28, 28])"
      ]
     },
     "execution_count": 203,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.unsqueeze(0).shape # 在0的前面插入一个维度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 204,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 1, 28, 28, 1])"
      ]
     },
     "execution_count": 204,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.unsqueeze(-1).shape # 在-1的后面插入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 213,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.1659, 0.4205],\n",
       "        [0.4678, 0.1701]])"
      ]
     },
     "execution_count": 213,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test = torch.rand(2,2)\n",
    "test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 214,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 2])"
      ]
     },
     "execution_count": 214,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 217,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.1659, 0.4205],\n",
       "         [0.4678, 0.1701]]])"
      ]
     },
     "execution_count": 217,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.unsqueeze_(0) # 只是增加了一个维度，并不改变数据本身。参数0表示在第0维之前插入新维度\n",
    "# 方法名末尾加下划线表示原地(in-place)操作，会直接修改原来的tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 218,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 2, 2])"
      ]
     },
     "execution_count": 218,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 208,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1, 2])"
      ]
     },
     "execution_count": 208,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test = torch.tensor([1,2])\n",
    "test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 211,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 211,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.dim()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 209,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1],\n",
       "        [2]])"
      ]
     },
     "execution_count": 209,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.unsqueeze(-1) # 这样展开为列向量"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "维度变化的一般用处"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 223,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = torch.rand(3,3,12,2)\n",
    "# 假如现在想在12这个维度加一个偏置\n",
    "bias = torch.rand(12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 224,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([12])"
      ]
     },
     "execution_count": 224,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bias.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 225,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 如果直接相加的话，维度不同是不能直接相加的\n",
    "bias = bias.unsqueeze(0).unsqueeze(0).unsqueeze(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 226,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 1, 12, 1])"
      ]
     },
     "execution_count": 226,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bias.shape # 后面会介绍1->3的方法。方法名叫扩张"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## tensor维度删减"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 227,
   "metadata": {},
   "outputs": [],
   "source": [
    "# squeeze\n",
    "test = torch.rand(2,2,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 228,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 2, 1])"
      ]
     },
     "execution_count": 228,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 232,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.9808],\n",
       "         [0.9460]],\n",
       "\n",
       "        [[0.1491],\n",
       "         [0.1736]]])"
      ]
     },
     "execution_count": 232,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 231,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.1491],\n",
       "        [0.1736]])"
      ]
     },
     "execution_count": 231,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test[1,...]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 229,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.9808, 0.9460],\n",
       "        [0.1491, 0.1736]])"
      ]
     },
     "execution_count": 229,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.squeeze() # 默认是把为1的维度去掉，并不改变数据本身"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 234,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 2, 1])"
      ]
     },
     "execution_count": 234,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.squeeze(0).shape # 0位置不是1，所以不会删除维度。没有报错，返回的不变"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 235,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 2])"
      ]
     },
     "execution_count": 235,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.squeeze(-1).shape # 为1就删除了"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 维度拓展"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 246,
   "metadata": {},
   "outputs": [],
   "source": [
    "# expand, repeat\n",
    "test = torch.rand(1,32,1,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 247,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 32, 1, 1])"
      ]
     },
     "execution_count": 247,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 248,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 假如现在要拓展为4，32，14，14\n",
    "\n",
    "test = test.expand(4,32,14,14)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 250,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "The expanded size of the tensor (5) must match the existing size (4) at non-singleton dimension 0.  Target sizes: [5, 32, 14, 14].  Tensor sizes: [4, 32, 14, 14]",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-250-4bbfccddf6ae>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexpand\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m32\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m14\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m14\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# 注意，只能维度值为1的进行扩张，其他维度并不能扩张\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m: The expanded size of the tensor (5) must match the existing size (4) at non-singleton dimension 0.  Target sizes: [5, 32, 14, 14].  Tensor sizes: [4, 32, 14, 14]"
     ]
    }
   ],
   "source": [
    "test.expand(5,32,14,14) # 注意，只能维度值为1的进行扩张，其他维度并不能扩张，3-5就不能，只能1-5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 253,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 32, 14, 14])"
      ]
     },
     "execution_count": 253,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.expand(-1,32,-1,-1).shape # -1表示这个维度不变"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 254,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 32, 14, 14])"
      ]
     },
     "execution_count": 254,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# repeat\n",
    "\n",
    "test = torch.rand(1,32,1,1)\n",
    "test.repeat(4,1,14,14).shape # repeat的参数表示各维度重复的次数：第0维重复4次，第1维的32保持不变所以填1。repeat会真正复制数据，因此这里不推荐使用"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 转置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 255,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = torch.rand(2,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 256,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.3665, 0.5724],\n",
       "        [0.1173, 0.2822]])"
      ]
     },
     "execution_count": 256,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.t()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 257,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "t() expects a tensor with <= 2 dimensions, but self is 3D",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-257-8fbe64eb0613>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mtest\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrand\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m: t() expects a tensor with <= 2 dimensions, but self is 3D"
     ]
    }
   ],
   "source": [
    "test = torch.rand(1,2,3) # 注意只有在二维的情况下才能转置\n",
    "test.t()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 交换维度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 258,
   "metadata": {},
   "outputs": [],
   "source": [
    "# transpose\n",
    "\n",
    "test = torch.rand(1,32,1,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 259,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([32, 1, 1, 1])"
      ]
     },
     "execution_count": 259,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.transpose(0,1).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 260,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([32, 1, 1, 1])"
      ]
     },
     "execution_count": 260,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# permute\n",
    "\n",
    "test.permute(1,2,3,0).shape # 直接传入位置即可"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 张量相加"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在假设有个场景，tensor+scalar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 261,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = torch.rand(2,12,4) # 2个班级，每班级12名学生，每学生学4门课程得分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 263,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1])"
      ]
     },
     "execution_count": 263,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 现在假如我要给学生的每门课程加3分\n",
    "# 也就是实现张量加标量\n",
    "\n",
    "bias = torch.tensor([3])\n",
    "bias.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 265,
   "metadata": {},
   "outputs": [],
   "source": [
    "bias = bias.unsqueeze(0).unsqueeze(0).expand(2,12,4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 267,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 12, 4])"
      ]
     },
     "execution_count": 267,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bias.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 268,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3]],\n",
       "\n",
       "        [[3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3],\n",
       "         [3, 3, 3, 3]]])"
      ]
     },
     "execution_count": 268,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 269,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[3.5926, 3.6585, 3.5626, 3.7852],\n",
       "         [3.3929, 3.1034, 3.3518, 3.7314],\n",
       "         [3.2450, 3.1379, 3.8641, 3.9138],\n",
       "         [3.4566, 3.7594, 3.2158, 3.7839],\n",
       "         [3.2208, 3.7148, 3.0081, 3.4197],\n",
       "         [3.4276, 3.0490, 3.2225, 3.5248],\n",
       "         [3.9596, 3.6256, 3.1139, 3.5801],\n",
       "         [3.4581, 3.6182, 3.0749, 3.3375],\n",
       "         [3.3139, 3.8552, 3.4132, 3.8696],\n",
       "         [3.9641, 3.3554, 3.6175, 3.8541],\n",
       "         [3.1284, 3.9316, 3.5702, 3.9627],\n",
       "         [3.6021, 3.9347, 3.7362, 3.8569]],\n",
       "\n",
       "        [[3.3054, 3.8866, 3.6174, 3.2519],\n",
       "         [3.7663, 3.4736, 3.9203, 3.5704],\n",
       "         [3.3170, 3.5279, 3.8722, 3.3639],\n",
       "         [3.3188, 3.2035, 3.3834, 3.0632],\n",
       "         [3.5136, 3.0656, 3.6519, 3.5914],\n",
       "         [3.4435, 3.0767, 3.5103, 3.5679],\n",
       "         [3.8727, 3.5434, 3.7100, 3.6311],\n",
       "         [3.9628, 3.6173, 3.2581, 3.0761],\n",
       "         [3.0588, 3.5210, 3.5379, 3.3656],\n",
       "         [3.5046, 3.8208, 3.3258, 3.6775],\n",
       "         [3.1048, 3.6571, 3.4247, 3.7635],\n",
       "         [3.3865, 3.0250, 3.5494, 3.1199]]])"
      ]
     },
     "execution_count": 269,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test + bias"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在假设我们只给最低维度某个科目加分怎么办"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 270,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4])"
      ]
     },
     "execution_count": 270,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bias = torch.tensor([0,0,2,0])\n",
    "bias.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 271,
   "metadata": {},
   "outputs": [],
   "source": [
    "bias = bias.unsqueeze(0).unsqueeze(0).expand(2,12,4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 272,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0]],\n",
       "\n",
       "        [[0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0],\n",
       "         [0, 0, 2, 0]]])"
      ]
     },
     "execution_count": 272,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bias"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 合并拼接"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 273,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cat\n",
    "\n",
    "test1 = torch.rand(3,2,4)\n",
    "test2 = torch.rand(4,2,4)\n",
    "test3 = torch.rand(3,2,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 278,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([7, 2, 4])"
      ]
     },
     "execution_count": 278,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.cat([test1,test2], dim=0).shape # 按行拼接"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 279,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 2, 5])"
      ]
     },
     "execution_count": 279,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.cat([test1,test3], dim=2).shape # cat可以选择你要的维度，dim=2表示最后面这个维度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 284,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 32, 3])"
      ]
     },
     "execution_count": 284,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# stack\n",
    "# stack会按指定的维度，然后创建一个新的维度\n",
    "\n",
    "\n",
    "# 考虑下面的场景/一位老师统计了两个班级的成绩。现在要将两个班级的成绩合并\n",
    "# 新增加的维度表示班级，0表示第一个班级，1表示第二个班级\n",
    "# stack只能使用在两个tensor的shape相同的情况下\n",
    "test1 = torch.rand(32,3)\n",
    "test2 = torch.rand(32,3)\n",
    "\n",
    "torch.stack([test1,test2],dim=0).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据拆分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 293,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5, 2, 1])"
      ]
     },
     "execution_count": 293,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# split\n",
    "a1 = torch.rand(2,2,1)\n",
    "a2 = torch.rand(3,2,1)\n",
    "\n",
    "test = torch.cat([a1,a2],dim=0)\n",
    "test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 294,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 把原先cat的数据拆分出去\n",
    "# 拆分成两个，按第0维度拆分，按2，3拆分\n",
    "b1, b2 = test.split([2,3],dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 296,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 2, 1])"
      ]
     },
     "execution_count": 296,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b1.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 299,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True)"
      ]
     },
     "execution_count": 299,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check that the split piece b1 reproduces a1 exactly: every element must match.\n",
    "# torch.all is required here; torch.any would be True even if only one element matched.\n",
    "torch.all(torch.eq(a1,b1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 300,
   "metadata": {},
   "outputs": [],
   "source": [
    "# chunk\n",
    "\n",
    "\n",
    "# chunk 是按数量拆分。4/2 = 2这种类型\n",
    "test = torch.rand(4,3,2)\n",
    "a, b = test.chunk(2,dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 301,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 3, 2])"
      ]
     },
     "execution_count": 301,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 将array转化成tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 1, 2, 3],\n",
       "        [4, 5, 6, 7]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Convert the NumPy array to a tensor; torch.from_numpy shares memory with the source array.\n",
    "torch_data = torch.from_numpy(np_data)\n",
    "torch_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Tensor"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(torch_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 将tensor转化成array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "tensor_to_array = torch_data.numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 一些常用函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "data1 = [1,9]\n",
    "data2 = [[1,2],[3,4]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "mean(): argument 'input' (position 1) must be Tensor, not list",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-24-3da3b29173b1>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m: mean(): argument 'input' (position 1) must be Tensor, not list"
     ]
    }
   ],
   "source": [
    "torch.mean(data1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5.0"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(data1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 在torch和numpy中，前者最好先将数据转化成tensor的形式，后者最好转化成array的形式\n",
    "# 这样转换后，才好通过运算\n",
    "array1 = np.array(data1)\n",
    "array2 = np.array(data2)\n",
    "tensor1 = torch.FloatTensor(data1)\n",
    "tensor2 = torch.FloatTensor(data2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(3.)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.mean(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0.8415,  0.4121, -0.8415,  0.1411])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.sin(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1., 9., 1., 3.])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.abs(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## tensor的基本运算"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 四则运算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 311,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = torch.rand(3,2)\n",
    "b = torch.rand(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 312,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 312,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.dim()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 313,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 313,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b.dim()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 314,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.9876, 0.4394],\n",
       "        [0.6017, 0.8031],\n",
       "        [0.7488, 1.0863]])"
      ]
     },
     "execution_count": 314,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a + b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 315,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.9876, 0.4394],\n",
       "        [0.6017, 0.8031],\n",
       "        [0.7488, 1.0863]])"
      ]
     },
     "execution_count": 315,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 可以看到两个add和+结果是一样的，这里建议直接使用+\n",
    "torch.add(a,b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 316,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 为什么两个维度不一样，也可以直接相加呢？\n",
    "# a的最低维度，也就是最右边是2\n",
    "# b的最低维度也是2.shape是一样的\n",
    "# 这是pytorch的broadcast机制。会自动拓展，且不会真正复制数据，可以省去很多的内存\n",
    "# 如果改用repeat手动复制数据的话，会消耗很多空间（expand只返回视图，同样不占用额外内存）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 317,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.1547,  0.2538],\n",
       "        [-0.2312,  0.6175],\n",
       "        [-0.0841,  0.9008]])"
      ]
     },
     "execution_count": 317,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a - b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 318,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.2379, 0.0322],\n",
       "        [0.0772, 0.0659],\n",
       "        [0.1384, 0.0922]])"
      ]
     },
     "execution_count": 318,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a * b # 相同位置进行相乘"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 319,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.3716,  3.7352],\n",
       "        [ 0.4449,  7.6554],\n",
       "        [ 0.7980, 10.7079]])"
      ]
     },
     "execution_count": 319,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a / b"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### tensor乘法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 320,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.5968, 1.6098],\n",
       "        [1.8018, 1.3237],\n",
       "        [1.4462, 0.8854]])"
      ]
     },
     "execution_count": 320,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 注意这里的乘法是矩阵的乘法\n",
    "\n",
    "# 矩阵的乘法有mm和matmul两个接口\n",
    "\n",
    "a = torch.rand(3,4)\n",
    "b = torch.rand(4,2)\n",
    "\n",
    "torch.mm(a,b) # mm只适用于2D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 321,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.5968, 1.6098],\n",
       "        [1.8018, 1.3237],\n",
       "        [1.4462, 0.8854]])"
      ]
     },
     "execution_count": 321,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# matmul适用于多维，这里推荐使用matmul\n",
    "# 等价于matmul\n",
    "a @ b"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 利用乘法进行降维"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 324,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([100, 10])"
      ]
     },
     "execution_count": 324,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.rand(100,219)\n",
    "b = torch.rand(10,219)\n",
    "\n",
    "(a @ b.t()).shape # 有什么用处呢？全连接层输入是100， 最后输出神经元是10.这就是降维过程"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 多维的矩阵相乘"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 327,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 20, 32])"
      ]
     },
     "execution_count": 327,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 这里dim 1 一个是3，一个是1.维度不一样，但是也可以运算，是使用了broadcast机制\n",
    "# 1会复制数据到3维，然后再相乘。\n",
    "\n",
    "a = torch.rand(4,3,20,64)\n",
    "b = torch.rand(4,1,64,32)\n",
    "\n",
    "(a@b).shape  # 多维矩阵，可以要求最后两维的数据变化。但高维的要不变。后面两维乘法相当于2维乘法"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 次方运算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 328,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.2450, 0.4258],\n",
       "        [0.6737, 0.0015]])"
      ]
     },
     "execution_count": 328,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.rand(2,2)\n",
    "a**2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 329,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.7035, 0.8078],\n",
       "        [0.9060, 0.1963]])"
      ]
     },
     "execution_count": 329,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.sqrt()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 对数运算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 334,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[2.7183, 2.7183],\n",
       "        [2.7183, 2.7183]])"
      ]
     },
     "execution_count": 334,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.exp(torch.full([2,2],1.0)) # e的1次方"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 337,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.3863, 1.3863, 1.3863],\n",
       "        [1.3863, 1.3863, 1.3863],\n",
       "        [1.3863, 1.3863, 1.3863]])"
      ]
     },
     "execution_count": 337,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.log(torch.full([3,3],4.))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## tensor裁剪"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 338,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clamp函数\n",
    "\n",
    "test = torch.rand(4,4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 339,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = test*14"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 341,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 4.2245, 13.1836,  1.0374,  7.9265],\n",
       "        [ 4.7323,  3.0531,  8.0399, 11.4777],\n",
       "        [ 2.6512,  2.6686,  2.9348,  8.7696],\n",
       "        [11.4041, 12.8540,  2.7486,  9.6565]])"
      ]
     },
     "execution_count": 341,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 343,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(13.1836)"
      ]
     },
     "execution_count": 343,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 344,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.0374)"
      ]
     },
     "execution_count": 344,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.min()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 342,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[10.0000, 13.1836, 10.0000, 10.0000],\n",
       "        [10.0000, 10.0000, 10.0000, 11.4777],\n",
       "        [10.0000, 10.0000, 10.0000, 10.0000],\n",
       "        [11.4041, 12.8540, 10.0000, 10.0000]])"
      ]
     },
     "execution_count": 342,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.clamp(10) # 只传入10，表示裁剪最小值。小于10的都赋值为10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 345,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 4.2245, 10.0000,  1.0374,  7.9265],\n",
       "        [ 4.7323,  3.0531,  8.0399, 10.0000],\n",
       "        [ 2.6512,  2.6686,  2.9348,  8.7696],\n",
       "        [10.0000, 10.0000,  2.7486,  9.6565]])"
      ]
     },
     "execution_count": 345,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.clamp(0,10) # 设置一个阈值，在0-10之间，小于0的都会赋值为0，大于10的都会赋值为10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 矩阵相乘"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 7, 10],\n",
       "       [15, 22]])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "array2.dot(array2) # 矩阵的乘法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([19, 39])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "array2.dot(array1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 1, 18],\n",
       "       [ 3, 36]])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "array1*array2 # 甚至可以按位运算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1., 9.])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tensor1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 2.],\n",
       "        [3., 4.]])"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tensor2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(82.)"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tensor1.dot(tensor1) # 1*1+9*9"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torch的属性统计"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 范数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 347,
   "metadata": {},
   "outputs": [],
   "source": [
    "# norm\n",
    "# 可以自己设置1，2和维度\n",
    "\n",
    "a = torch.full([8],1.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 348,
   "metadata": {},
   "outputs": [],
   "source": [
    "b = a.view(2,4)\n",
    "c = a.view(2,2,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 349,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(8.)"
      ]
     },
     "execution_count": 349,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.norm(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 350,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(8.)"
      ]
     },
     "execution_count": 350,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b.norm(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 351,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(2.8284)"
      ]
     },
     "execution_count": 351,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.norm(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 352,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(2.8284)"
      ]
     },
     "execution_count": 352,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b.norm(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 353,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(8.)"
      ]
     },
     "execution_count": 353,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.norm(1,dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 354,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2., 2., 2., 2.])"
      ]
     },
     "execution_count": 354,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b.norm(1,dim=0) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 356,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1., 1.],\n",
       "         [1., 1.]],\n",
       "\n",
       "        [[1., 1.],\n",
       "         [1., 1.]]])"
      ]
     },
     "execution_count": 356,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 355,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[2., 2.],\n",
       "        [2., 2.]])"
      ]
     },
     "execution_count": 355,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c.norm(1,dim=0) # 把0维度最后会消掉"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[\n",
    "    [[1]]\n",
    "    [[1]]\n",
    "]\n",
    "\n",
    "这两个位置相加"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 排序"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 357,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = torch.arange(8).view(2,4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 362,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 1, 2, 3],\n",
       "        [4, 5, 6, 7]])"
      ]
     },
     "execution_count": 362,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 358,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(7), tensor(0))"
      ]
     },
     "execution_count": 358,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res.argmax(),res.argmin()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 363,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([3, 3])"
      ]
     },
     "execution_count": 363,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res.argmax(dim=1) # 当然你也可以设置维度，返回的是元素的位置，不是具体的值了"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 最大值和最小值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 365,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.return_types.max(\n",
       "values=tensor([0.8874, 0.8856, 0.8725]),\n",
       "indices=tensor([1, 1, 0]))"
      ]
     },
     "execution_count": 365,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test = torch.rand(2,3)\n",
    "\n",
    "test.max(dim=0) # dim=0就是消掉2，返回一个shape是3的"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 367,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.8874, 0.8856, 0.8725]])"
      ]
     },
     "execution_count": 367,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.max(dim=0,keepdim=True).values # 可以发现维度还是2，3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 371,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.3170, 0.2673, 0.9839, 0.4889, 0.6655],\n",
       "        [0.8522, 0.4805, 0.3176, 0.4276, 0.0217],\n",
       "        [0.2843, 0.6133, 0.7434, 0.0560, 0.9626]])"
      ]
     },
     "execution_count": 371,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# topk\n",
    "\n",
    "a = torch.rand(3,5)\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 372,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.return_types.topk(\n",
       "values=tensor([[0.9839, 0.6655, 0.4889],\n",
       "        [0.8522, 0.4805, 0.4276],\n",
       "        [0.9626, 0.7434, 0.6133]]),\n",
       "indices=tensor([[2, 4, 3],\n",
       "        [0, 1, 3],\n",
       "        [4, 2, 1]]))"
      ]
     },
     "execution_count": 372,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.topk(3,dim=1)  # top-3 largest values of each row, together with their column indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 375,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.return_types.kthvalue(\n",
       "values=tensor([0.2673, 0.0217, 0.0560]),\n",
       "indices=tensor([1, 4, 3]))"
      ]
     },
     "execution_count": 375,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.kthvalue(1,dim=1) # 返回对应维度的最小值"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 比较运算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 376,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = torch.rand(3,4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 380,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.3471, 0.7120, 0.2535, 0.0636],\n",
       "        [0.5709, 0.4324, 0.5272, 0.3198],\n",
       "        [0.3665, 0.7102, 0.4018, 0.7993]])"
      ]
     },
     "execution_count": 380,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 377,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[True, True, True, True],\n",
       "        [True, True, True, True],\n",
       "        [True, True, True, True]])"
      ]
     },
     "execution_count": 377,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test > 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 379,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[False, False, False, False],\n",
       "        [False, False, False, False],\n",
       "        [False, False, False, False]])"
      ]
     },
     "execution_count": 379,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test < 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 382,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[True, True, True, True],\n",
       "        [True, True, True, True],\n",
       "        [True, True, True, True]])"
      ]
     },
     "execution_count": 382,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test == test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## where,gather"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### where"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 415,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 1],\n",
       "        [2, 3]])"
      ]
     },
     "execution_count": 415,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.arange(4).view(2,2).type(torch.LongTensor)\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 416,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[4, 5],\n",
       "        [6, 7]])"
      ]
     },
     "execution_count": 416,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b = torch.arange(4,8).view(2,2).type(torch.LongTensor)\n",
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 417,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[4, 5],\n",
       "        [2, 3]])"
      ]
     },
     "execution_count": 417,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 看下where命令\n",
    "# where 就是 a if 1 else b\n",
    "\n",
    "torch.where(a>1, a, b) # a,b 的数据都要转化成float类型。不然会报错"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### gather"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 418,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Help on built-in function gather:\n",
      "\n",
      "gather(...)\n",
      "    gather(input, dim, index, out=None, sparse_grad=False) -> Tensor\n",
      "    \n",
      "    Gathers values along an axis specified by `dim`.\n",
      "    \n",
      "    For a 3-D tensor the output is specified by::\n",
      "    \n",
      "        out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0\n",
      "        out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1\n",
      "        out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2\n",
      "    \n",
      "    If :attr:`input` is an n-dimensional tensor with size\n",
      "    :math:`(x_0, x_1..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})`\n",
      "    and ``dim = i``, then :attr:`index` must be an :math:`n`-dimensional tensor with\n",
      "    size :math:`(x_0, x_1, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})` where :math:`y \\geq 1`\n",
      "    and :attr:`out` will have the same size as :attr:`index`.\n",
      "    \n",
      "    Args:\n",
      "        input (Tensor): the source tensor\n",
      "        dim (int): the axis along which to index\n",
      "        index (LongTensor): the indices of elements to gather\n",
      "        out (Tensor, optional): the destination tensor\n",
      "        sparse_grad(bool,optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor.\n",
      "    \n",
      "    Example::\n",
      "    \n",
      "        >>> t = torch.tensor([[1,2],[3,4]])\n",
      "        >>> torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))\n",
      "        tensor([[ 1,  1],\n",
      "                [ 4,  3]])\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# gather就是查表\n",
    "# v、https://zhuanlan.zhihu.com/p/101896024\n",
    "help(torch.gather)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 420,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[4, 5],\n",
       "        [2, 3]])"
      ]
     },
     "execution_count": 420,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.randint(1,10,(2,2))\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 431,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 0],\n",
       "        [1, 0]])"
      ]
     },
     "execution_count": 431,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b = torch.randint(0,2,(2,2))\n",
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 433,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[4, 4],\n",
       "        [3, 2]])"
      ]
     },
     "execution_count": 433,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.gather(a, 1, b) # dim为1表示按不同列索引，0表示4，1对应3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 435,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[4, 5],\n",
       "        [2, 5]])"
      ]
     },
     "execution_count": 435,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.gather(a, 0, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 437,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.1024,  0.6968,  1.0233, -0.5221,  1.0193, -0.5163, -0.3523, -0.2783,\n",
       "          0.6183, -0.2060],\n",
       "        [ 2.4978, -0.2458,  2.2695, -1.7554, -2.6681, -0.9304,  0.4501,  0.5302,\n",
       "          0.3926, -0.1445],\n",
       "        [-0.2548,  0.5078, -0.7320, -0.8752,  1.8231,  0.9859,  0.6047,  1.4945,\n",
       "          0.9100, -0.4755],\n",
       "        [ 1.1922, -0.1419, -0.5319,  1.3735, -0.7563,  0.4903, -0.7779, -0.4317,\n",
       "          0.4654, -0.3214]])"
      ]
     },
     "execution_count": 437,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 举个实际的例子\n",
    "\n",
    "a = torch.randn(4, 10) # randn生成正太分布的数据\n",
    "a # 例如神经网络对4张照片的概率输出"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 439,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.return_types.topk(\n",
       "values=tensor([[1.0233],\n",
       "        [2.4978],\n",
       "        [1.8231],\n",
       "        [1.3735]]),\n",
       "indices=tensor([[2],\n",
       "        [0],\n",
       "        [4],\n",
       "        [3]]))"
      ]
     },
     "execution_count": 439,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res = a.topk(dim=1, k=1) # dim=1表示找到不同列的最大值，也就是每一行的最大值\n",
    "res"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 使用torch完成Minist手写数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from torch import optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "#加载数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 500\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    torchvision.datasets.MNIST('mnist data', train=True, download=True, transform=torchvision.transforms.Compose(\n",
    "        [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.13,),(0.3,))]\n",
    "    )),\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loader = torch.utils.data.DataLoader(\n",
    "    torchvision.datasets.MNIST('mnist data', train=False, download=True, transform=torchvision.transforms.Compose(\n",
    "        [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.13,),(0.3,))]\n",
    "    )),\n",
    "    batch_size=batch_size,\n",
    "    shuffle=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "i,j = next(iter(train_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([500, 1, 28, 28])"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "i.shape\n",
    "# 每个batch有500个数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([500])"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "j.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"Three-layer fully connected classifier for 28x28 digit images.\n",
    "\n",
    "    Layer sizes are 784 -> 256 -> 64 -> 10. The two hidden widths are\n",
    "    arbitrary design choices; the output has one unit per digit 0-9.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.func1 = nn.Linear(28*28, 256)\n",
    "        self.func2 = nn.Linear(256, 64)\n",
    "        self.func3 = nn.Linear(64, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the raw (un-activated) scores of the last layer.\n",
    "\n",
    "        Accepts already-flattened input whose last dimension is 784, or an\n",
    "        image batch such as [batch, 1, 28, 28], which is flattened here.\n",
    "        \"\"\"\n",
    "        # nn.Linear acts on the last dimension only, so image-shaped input\n",
    "        # has to be flattened; already-flattened input passes through as is.\n",
    "        if x.size(-1) != self.func1.in_features:\n",
    "            x = x.flatten(start_dim=1)\n",
    "        # h1 = relu(x @ w1 + b1); h2 = relu(h1 @ w2 + b2); out = h2 @ w3 + b3\n",
    "        x = F.relu(self.func1(x))\n",
    "        x = F.relu(self.func2(x))\n",
    "        return self.func3(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([500, 1, 28, 28]) torch.Size([500])\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'one_hot' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-80-d94253743052>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      7\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m28\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;36m28\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m         \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m         \u001b[0my_onehot\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mone_hot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     10\u001b[0m         \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmes_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_onehot\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'one_hot' is not defined"
     ]
    }
   ],
   "source": [
    "net = Net()\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.01, momentum = 0.9)\n",
    "\n",
    "for epoch in range(3): # 对数据集，迭代3遍\n",
    "    for batch_index, (x,y) in enumerate(train_loader): # 开始对数据集的每个batch迭代第一遍\n",
    "        print(x.shape,y.shape)\n",
    "        x = x.view(x.size(0), 28*28)\n",
    "        out = net(x)\n",
    "        y_onehot = one_hot(y)\n",
    "        loss = F.mes_loss(out, y_onehot)\n",
    "        \n",
    "        loss.backward()\n",
    "        \n",
    "        optimizer.step()\n",
    "        \n",
    "        if batch_inde % 10 == 0:\n",
    "            print(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "292px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
